import asyncio
import argparse
import json
import os
import re
import time
import traceback
import aiohttp 
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
from collections import defaultdict
from datetime import datetime
from core.gsm8k_reasoner import Gsm8kReasoner
from core.asdiv_reasoner import ASDivReasoner
from core.svamp_reasoner import SVAMPReasoner
from core.strategyqa_reasoner import StrategyQAReasoner
from core.openbookqa_reasoner import OpenBookQAReasoner

# Dataset name (the values accepted by --dataset) -> reasoner class used to
# load and solve that dataset's problems.
REASONER_MAP = dict(
    gsm8k=Gsm8kReasoner,
    asdiv=ASDivReasoner,
    svamp=SVAMPReasoner,
    strategyqa=StrategyQAReasoner,
    openbookqa=OpenBookQAReasoner,
)

def _parse_args() -> argparse.Namespace:
    """Parse the command line and reject an inverted --start/--end range."""
    parser = argparse.ArgumentParser(description="Reasoning System")
    parser.add_argument("--dataset", type=str, required=True,
                        # Derived from REASONER_MAP so the CLI can never offer a
                        # dataset that has no reasoner (or miss one that does).
                        choices=list(REASONER_MAP),
                        help="Dataset to evaluate")
    parser.add_argument("--start", type=int, default=0, help="Start index")
    parser.add_argument("--end", type=int, default=1, help="End index")
    args = parser.parse_args()
    if args.end < args.start:
        # parser.error prints usage and exits with status 2.
        parser.error(
            f"--end ({args.end}) must not be smaller than --start ({args.start})"
        )
    return args


async def _solve(reasoner, dataset: str, problem):
    """Run ``reasoner.execute_workflow`` with the input each dataset expects.

    gsm8k / asdiv / svamp pass a single text field of the problem record;
    every other dataset passes the whole record. Returns whatever the
    reasoner's workflow returns.
    """
    if dataset == "gsm8k":
        return await reasoner.execute_workflow(problem["question"])
    if dataset == "asdiv":
        return await reasoner.execute_workflow(problem["text"])
    if dataset == "svamp":
        return await reasoner.execute_workflow(
            question=problem["question_concat"]
        )
    return await reasoner.execute_workflow(problem)


async def main() -> None:
    """Evaluate one dataset slice and write a JSON report under ``log/<dataset>/``.

    Problems in ``[--start, --end)`` (as interpreted by the reasoner's
    ``load_problems``) are solved one at a time. A problem whose workflow or
    result preparation raises is recorded in ``errors`` and counted as
    attempted-but-incorrect instead of aborting the run, so results already
    computed are never lost.

    The report contains ``results`` (the reasoner's prepared results),
    ``stats`` (totals, failures and accuracy in percent), ``errors`` (index
    and traceback of each failed problem) and ``nodes_count_distribution``.
    """
    args = _parse_args()

    os.makedirs(f"log/{args.dataset}", exist_ok=True)

    global_stats = {
        "total_problems": 0,
        "correct_answers": 0,
        "accuracy": 0.0,
        "failed_problems": 0,
    }
    all_results = []
    errors = []
    nodes_count_distribution = defaultdict(int)

    ReasonerClass = REASONER_MAP[args.dataset]
    problems = await ReasonerClass().load_problems(args.start, args.end)

    for idx, problem in enumerate(problems, args.start):
        print(f"\nProcessing problem {idx}...")
        global_stats["total_problems"] += 1
        # A new reasoner is built per problem (presumably so per-problem state
        # does not carry over between problems).
        reasoner = ReasonerClass()

        try:
            result = await _solve(reasoner, args.dataset, problem)
            prepared_result = reasoner.save_results(result, problem)
        except Exception:
            # Keep going: one bad problem (e.g. a network/API error) must not
            # discard the whole run. The failure stays in total_problems, so it
            # lowers accuracy rather than silently inflating it.
            tb = traceback.format_exc()
            global_stats["failed_problems"] += 1
            errors.append({"index": idx, "traceback": tb})
            print(f"Problem {idx} FAILED:\n{tb}")
            continue

        all_results.append(prepared_result)

        nodes_created = len(result.get('nodes', {}))
        nodes_count_distribution[nodes_created] += 1

        verification = prepared_result["verification"]
        if verification["is_correct"]:
            global_stats["correct_answers"] += 1

        print(f"Final Answer: {result.get('final_answer', 'None')}")
        print(f"Correct answer: {verification['correct_answer']}")
        print(f"Verification: {'CORRECT' if verification['is_correct'] else 'INCORRECT'}")

    if global_stats["total_problems"] > 0:
        global_stats["accuracy"] = (
            global_stats["correct_answers"] / global_stats["total_problems"] * 100
        )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    accuracy = global_stats["accuracy"]
    filename = f"log/{args.dataset}/results_{args.start}_{args.end}_{timestamp}_acc{accuracy:.1f}%.json"
    with open(filename, "w", encoding="utf-8") as f:
        json.dump({
            "results": all_results,
            "stats": global_stats,
            "errors": errors,
            # JSON object keys are strings; the int node counts are converted.
            "nodes_count_distribution": dict(sorted(nodes_count_distribution.items())),
        }, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to {filename}")
    print(f"Final Accuracy: {global_stats['accuracy']:.2f}%")
    if global_stats["failed_problems"]:
        print(f"Failed problems: {global_stats['failed_problems']}")
    print("\nNodes Created Distribution:")
    for count, num_problems in sorted(nodes_count_distribution.items()):
        print(f"Problems with {count} nodes: {num_problems}")

if __name__ == "__main__":
    # Script entry point: asyncio.run creates an event loop, runs main() to
    # completion, then closes the loop.
    asyncio.run(main())